import gensim
import numpy as np
import jieba
from gensim.models.doc2vec import Doc2Vec,TaggedDocument


def jieba_tokenize(text):
    """
    Segment a piece of (Chinese) text into words using jieba.

    :param text: the text to segment
    :return: list of word tokens, in the order jieba produced them
    """
    # jieba.cut yields tokens lazily; materialising it is what jieba.lcut does.
    tokens = list(jieba.cut(text))
    return tokens


def get_datasest(path='toutiao_cat_data.txt'):
    """
    Build the doc2vec training corpus from a ``_!_``-separated data file.

    Each line is split on ``_!_``. Every line with at least four fields
    yields one TaggedDocument:
    - its words are the jieba tokens of the fourth field (index 3);
    - its single tag is the document's position in the returned list.

    :param path: path of the UTF-8 data file (defaults to the original
        hard-coded ``toutiao_cat_data.txt``)
    :return: list of TaggedDocument such that ``x_train[i].tags == [i]``
    """
    x_train = []
    with open(path, encoding='utf8') as fin:
        for line in fin:
            # Drop the line terminator so it does not become a token.
            fields = line.rstrip('\r\n').split('_!_')
            if len(fields) > 3:
                # Doc2Vec expects a list of words, not a raw string
                # (a string would be iterated character by character).
                words = jieba_tokenize(fields[3])
                # Tag by position so tags returned by most_similar can be
                # mapped straight back to the document in x_train.
                x_train.append(TaggedDocument(words, tags=[len(x_train)]))
    return x_train


def train(x_train, size=2000, epoch_num=10, model_path='model'):
    """
    Train a Doc2Vec model on ``x_train`` and save it to disk.

    The vocabulary is built explicitly and ``train`` is called exactly once.
    Passing the corpus to the constructor as well would start an extra,
    implicit training run before the explicit one.

    :param x_train: list of TaggedDocument to train on
    :param size: dimensionality of the learned vectors
    :param epoch_num: number of training epochs
    :param model_path: file path passed to ``model.save``
    :return: the trained Doc2Vec model
    """
    # vector_size/epochs are accepted by both gensim 3.x and 4.x;
    # the old `size` keyword is rejected by gensim >= 4.
    model_dm = Doc2Vec(min_count=1, window=3, vector_size=size, sample=1e-3,
                       negative=5, workers=4, epochs=epoch_num)
    model_dm.build_vocab(x_train)
    model_dm.train(x_train, total_examples=model_dm.corpus_count,
                   epochs=model_dm.epochs)
    model_dm.save(model_path)
    return model_dm


def getVecs(model, corpus, size):
    """
    Stack the learned document vectors for ``corpus`` into one 2-D array.

    :param model: trained Doc2Vec model whose ``docvecs`` is indexed by tag
    :param corpus: iterable of tagged documents; only ``tags[0]`` is used
    :param size: vector dimensionality; each vector is reshaped to (1, size)
    :return: ndarray with one row per document, in corpus order
    """
    rows = []
    for doc in corpus:
        first_tag = doc.tags[0]
        row = np.array(model.docvecs[first_tag].reshape(1, size))
        rows.append(row)
    return np.concatenate(rows)


def test(model_path="model", test_text=None, topn=10):
    """
    Load a saved Doc2Vec model and find the documents most similar to a query.

    :param model_path: path passed to ``Doc2Vec.load``
    :param test_text: pre-tokenized query (list of str); defaults to a
        sample question about recommending a car
    :param topn: number of neighbours to return
    :return: list of (tag, similarity) pairs from ``docvecs.most_similar``
    """
    if test_text is None:
        test_text = ['想换个', '30', '万左右', '的', '车', '，', '现在', '开科鲁兹', '，', '有', '什么', '好', '推荐', '的', '？']
    model_dm = Doc2Vec.load(model_path)
    inferred_vector_dm = model_dm.infer_vector(test_text)
    return model_dm.docvecs.most_similar([inferred_vector_dm], topn=topn)


if __name__ == '__main__':
    x_train = get_datasest()
    train(x_train)

    sims = test()
    # most_similar yields (tag, similarity) pairs. Resolve each tag to its
    # document instead of assuming the tag equals a list index.
    docs_by_tag = {doc.tags[0]: doc for doc in x_train}
    for tag, sim in sims:
        sentence = docs_by_tag[tag]
        words = ' '.join(sentence.words)
        print(words, sim, len(sentence.words))
